{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02163.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02163.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rD43vHrURxgR",
        "colab_type": "text"
      },
      "source": [
        "## 5.10 批量归一化"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ay3HeFy9SK5N",
        "colab_type": "text"
      },
      "source": [
        "### 5.10.2 从零开始实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fn_YYOqnST9S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "import time \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "\n",
        "def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
        "    \"\"\"Batch-normalize X and return the updated moving statistics.\n",
        "\n",
        "    Args:\n",
        "        is_training: if true, normalize with the statistics of X itself and fold\n",
        "            them into the moving averages; otherwise normalize with the moving\n",
        "            averages and leave them untouched.\n",
        "        X: input tensor. In training mode it must be 2-D (statistics over dim 0)\n",
        "            or 4-D (statistics over dims 0, 2 and 3).\n",
        "        gamma: scale applied to the normalized input.\n",
        "        beta: shift added after scaling.\n",
        "        moving_mean: running mean used in evaluation mode.\n",
        "        moving_var: running variance used in evaluation mode.\n",
        "        eps: constant added to the variance before taking the square root.\n",
        "        momentum: weight kept on the previous moving statistic; the batch\n",
        "            statistic receives weight 1 - momentum.\n",
        "\n",
        "    Returns:\n",
        "        Tuple (Y, moving_mean, moving_var). The moving statistics produced in\n",
        "        training mode are computed outside autograd.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: in training mode, if X is neither 2-D nor 4-D.\n",
        "    \"\"\"\n",
        "    if not is_training:\n",
        "        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
        "    else:\n",
        "        assert X.dim() in (2, 4), 'batch_norm expects a 2-D or 4-D input'\n",
        "        if X.dim() == 2:\n",
        "            # 2-D input: one mean/variance per column.\n",
        "            mean = X.mean(dim=0)\n",
        "            var = ((X - mean) ** 2).mean(dim=0)\n",
        "        else:\n",
        "            # 4-D input: one mean/variance per index of dim 1, kept broadcastable.\n",
        "            mean = X.mean(dim=(0, 2, 3), keepdim=True)\n",
        "            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
        "        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
        "        # Update the running statistics outside autograd. Otherwise every new\n",
        "        # moving_mean keeps a grad_fn that points at the previous one, so the\n",
        "        # graph grows with each training step and the tensors become non-leaf.\n",
        "        with torch.no_grad():\n",
        "            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
        "            moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
        "    Y = gamma * X_hat + beta\n",
        "    return Y, moving_mean, moving_var"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1meq69SKX1TV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BatchNorm(nn.Module):\n",
        "    \"\"\"Batch-normalization layer built on the batch_norm function.\n",
        "\n",
        "    Args:\n",
        "        num_features: size of the feature dimension (dim 1 of the input).\n",
        "        num_dims: 2 creates parameters of shape (1, num_features); any other\n",
        "            value creates parameters of shape (1, num_features, 1, 1).\n",
        "        eps: constant added to the variance before taking the square root.\n",
        "        momentum: weight kept on the previous moving statistic when it is\n",
        "            updated during training.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):\n",
        "        super(BatchNorm, self).__init__()\n",
        "        if num_dims == 2:\n",
        "            shape = (1, num_features)\n",
        "        else:\n",
        "            shape = (1, num_features, 1, 1)\n",
        "        self.eps = eps\n",
        "        self.momentum = momentum\n",
        "        # Learnable scale and shift.\n",
        "        self.gamma = nn.Parameter(torch.ones(shape))\n",
        "        self.beta = nn.Parameter(torch.zeros(shape))\n",
        "        # Buffers travel with .to(device) and are stored in state_dict(), so the\n",
        "        # running statistics survive saving/loading and need no manual device move.\n",
        "        self.register_buffer('moving_mean', torch.zeros(shape))\n",
        "        # Start the variance at one so an untrained layer in eval mode does not\n",
        "        # divide by sqrt(eps).\n",
        "        self.register_buffer('moving_var', torch.ones(shape))\n",
        "\n",
        "    def forward(self, X):\n",
        "        \"\"\"Normalize X; in training mode also refresh the moving statistics.\"\"\"\n",
        "        Y, moving_mean, moving_var = batch_norm(\n",
        "            self.training, X, self.gamma, self.beta,\n",
        "            self.moving_mean, self.moving_var, eps=self.eps, momentum=self.momentum)\n",
        "        # Detach so the stored statistics never hold on to an autograd graph.\n",
        "        self.moving_mean = moving_mean.detach()\n",
        "        self.moving_var = moving_var.detach()\n",
        "        return Y"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VZNwFr7-XB6D",
        "colab_type": "text"
      },
      "source": [
        "#### 1 使用批量归一化层的 LeNet"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4F_VHWz3XbKy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# LeNet variant: the hand-written BatchNorm sits between each conv/linear layer\n",
        "# and its sigmoid activation.\n",
        "net = nn.Sequential(\n",
        "    # First convolutional stage.\n",
        "    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),\n",
        "    BatchNorm(num_features=6, num_dims=4),\n",
        "    nn.Sigmoid(),\n",
        "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
        "    # Second convolutional stage.\n",
        "    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),\n",
        "    BatchNorm(num_features=16, num_dims=4),\n",
        "    nn.Sigmoid(),\n",
        "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
        "    # Fully connected head; 2-D inputs from here on.\n",
        "    d2l.FlattenLayer(),\n",
        "    nn.Linear(in_features=16 * 4 * 4, out_features=120),\n",
        "    BatchNorm(num_features=120, num_dims=2),\n",
        "    nn.Sigmoid(),\n",
        "    nn.Linear(in_features=120, out_features=84),\n",
        "    BatchNorm(num_features=84, num_dims=2),\n",
        "    nn.Sigmoid(),\n",
        "    nn.Linear(in_features=84, out_features=10),\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6mCMR4sKYKY8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "18233860-a07d-4a02-fdc4-bcfdeabe2dea"
      },
      "source": [
        "# Hyperparameters for this run.\n",
        "batch_size = 256\n",
        "lr = 0.001\n",
        "num_epochs = 5\n",
        "\n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
        "\n",
        "# Adam over every parameter of the network defined above.\n",
        "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
        "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.9796, train acc 0.794, test acc 0.802, time 6.4 sec\n",
            "epoch 2, loss 0.4457, train acc 0.869, test acc 0.858, time 6.4 sec\n",
            "epoch 3, loss 0.3565, train acc 0.883, test acc 0.849, time 6.3 sec\n",
            "epoch 4, loss 0.3253, train acc 0.887, test acc 0.872, time 6.3 sec\n",
            "epoch 5, loss 0.3031, train acc 0.894, test acc 0.874, time 6.3 sec\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zJ5xH1W9Ym7Y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "c14798de-a530-473a-cf5a-fe77c30f2924"
      },
      "source": [
        "# Flattened scale and shift of the first batch-normalization layer.\n",
        "tuple(param.view((-1, )) for param in (net[1].gamma, net[1].beta))"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(tensor([0.9841, 0.9523, 1.0258, 1.2424, 1.1701, 1.0703], device='cuda:0',\n",
              "        grad_fn=<ViewBackward>),\n",
              " tensor([ 0.1280, -0.3909,  0.0480,  0.5539,  0.0341, -0.1078], device='cuda:0',\n",
              "        grad_fn=<ViewBackward>))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EGUEwHhPYw5Y",
        "colab_type": "text"
      },
      "source": [
        "### 5.10.3 简洁实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7ub6fpqZjl-6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Same LeNet layout, now using PyTorch's built-in batch-normalization layers.\n",
        "net = nn.Sequential(\n",
        "    # First convolutional stage.\n",
        "    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),\n",
        "    nn.BatchNorm2d(num_features=6),\n",
        "    nn.Sigmoid(),\n",
        "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
        "    # Second convolutional stage.\n",
        "    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),\n",
        "    nn.BatchNorm2d(num_features=16),\n",
        "    nn.Sigmoid(),\n",
        "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
        "    # Fully connected head.\n",
        "    d2l.FlattenLayer(),\n",
        "    nn.Linear(in_features=16 * 4 * 4, out_features=120),\n",
        "    nn.BatchNorm1d(num_features=120),\n",
        "    nn.Sigmoid(),\n",
        "    nn.Linear(in_features=120, out_features=84),\n",
        "    nn.BatchNorm1d(num_features=84),\n",
        "    nn.Sigmoid(),\n",
        "    nn.Linear(in_features=84, out_features=10),\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OEKZ15Jkkf7Q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "0d1cde54-070a-452e-ca38-95aef5ca6041"
      },
      "source": [
        "# Same hyperparameters as the from-scratch run.\n",
        "batch_size = 256\n",
        "lr = 0.001\n",
        "num_epochs = 5\n",
        "\n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
        "\n",
        "# Fresh Adam optimizer bound to the parameters of the rebuilt network.\n",
        "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
        "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.9948, train acc 0.784, test acc 0.804, time 5.8 sec\n",
            "epoch 2, loss 0.4557, train acc 0.865, test acc 0.847, time 5.8 sec\n",
            "epoch 3, loss 0.3670, train acc 0.879, test acc 0.841, time 5.8 sec\n",
            "epoch 4, loss 0.3294, train acc 0.888, test acc 0.854, time 5.8 sec\n",
            "epoch 5, loss 0.3093, train acc 0.892, test acc 0.836, time 5.8 sec\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}